{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from math import sqrt\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict results using U-net model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Crop the original images into small pieces (there are 52 samples in total)\n",
    "\n",
    "## Step 2: Read the cropped images for one sample, then upsample them to the input shape of the network (256 x 256 in this example)\n",
    "\n",
    "## Step 3: Create prediction results of one sample \n",
    "\n",
    "## Step 4: Stitch the results back into a whole image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_crop_write(read_path, write_path, N_pieces):\n",
    "    \"\"\"Cut every image in read_path into a square grid of tiles and save each tile.\n",
    "\n",
    "    input:\n",
    "    read_path  - folder with the full-size images; each entry is read as grayscale\n",
    "    write_path - existing folder that receives the tiles, named\n",
    "                 <digits of the source file name>_<row>_<col>.jpg\n",
    "    N_pieces   - total number of tiles per image; the grid has\n",
    "                 int(sqrt(N_pieces)) tiles per side\n",
    "\n",
    "    Returns None; the tiles written to disk are the only output.\n",
    "    Raises ValueError if N_pieces < 1 (np.split also raises ValueError when an\n",
    "    image side is not divisible by the number of tiles per side).\n",
    "    \"\"\"\n",
    "    n_side = int(sqrt(N_pieces))\n",
    "    if n_side < 1:\n",
    "        raise ValueError('N_pieces must be at least 1, got %r' % (N_pieces,))\n",
    "\n",
    "    for filename in sorted(os.listdir(read_path)):\n",
    "        img = cv2.imread(os.path.join(read_path, filename), cv2.IMREAD_GRAYSCALE)\n",
    "        if img is None:\n",
    "            # cv2.imread signals an unreadable / non-image file with None instead of raising\n",
    "            print('Skipping unreadable file: ' + filename)\n",
    "            continue\n",
    "\n",
    "        # Only the digits of the source name are kept; same value for every tile of this image\n",
    "        prefix = re.sub(r\"\\D\", \"\", filename)\n",
    "\n",
    "        # Horizontal strips first: shape (n_side, H / n_side, W)\n",
    "        row_strips = np.asarray(np.split(img, n_side, axis = 0))\n",
    "        # Then cut every strip along its width: tiles[col, row] is one tile\n",
    "        tiles = np.asarray(np.split(row_strips, n_side, axis = 2))\n",
    "\n",
    "        for i in range(n_side):\n",
    "            for j in range(n_side):\n",
    "                tile_name = prefix + '_' + str(i) + '_' + str(j) + '.jpg'\n",
    "                cv2.imwrite(os.path.join(write_path, tile_name), tiles[j, i, :, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "def create_filename(N_slices, N_image_start, N_image_stop):\n",
    "    \"\"\"List the tile file names '<image>_<j>_<k>.jpg' for a range of images.\n",
    "\n",
    "    input:\n",
    "    N_slices      - number of slices per image; int(sqrt(N_slices)) per side\n",
    "    N_image_start - index of the first image to read\n",
    "    N_image_stop  - index of the last image to read (inclusive)\n",
    "\n",
    "    Names are ordered by image index, then by the first tile index, then by\n",
    "    the second one.\n",
    "    \"\"\"\n",
    "    per_side = int(sqrt(N_slices))\n",
    "    return [\n",
    "        str(image_index)+'_'+str(first)+'_'+str(second)+'.jpg'\n",
    "        for image_index in range(N_image_start, N_image_stop+1)\n",
    "        for first in range(per_side)\n",
    "        for second in range(per_side)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def testGenerator(test_path,N_slices, N_image_start, N_image_stop,target_size = (256,256),flag_multi_class = False,as_gray = True):\n",
    "    \"\"\"Yield the tiles of the selected images one by one, shaped for prediction.\n",
    "\n",
    "    input:\n",
    "    test_path        - folder holding the tile images\n",
    "    N_slices         - number of slices per image\n",
    "    N_image_start    - index of the first image to read\n",
    "    N_image_stop     - index of the last image to read\n",
    "    target_size      - size every tile is resized to\n",
    "    flag_multi_class - when False a trailing axis of length 1 is appended\n",
    "    as_gray          - forwarded to io.imread\n",
    "\n",
    "    Each yielded array has an extra leading axis of length 1.\n",
    "    \"\"\"\n",
    "    # NOTE(review): os, io and trans are not imported by this notebook itself;\n",
    "    # presumably they arrive through `from model import *` - confirm.\n",
    "    for tile_name in create_filename(N_slices, N_image_start, N_image_stop):\n",
    "        tile = io.imread(os.path.join(test_path,tile_name), as_gray = as_gray)\n",
    "\n",
    "        # scale by 255 first, then resize to the requested size\n",
    "        tile = trans.resize(tile / 255,target_size)\n",
    "        if not flag_multi_class:\n",
    "            tile = np.expand_dims(tile, axis = -1)\n",
    "        yield np.expand_dims(tile, axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2):\n",
    "    \"\"\"Write every item of npyfile to save_path as '<index>_predict.jpg'.\n",
    "\n",
    "    With flag_multi_class the item goes through labelVisualize first;\n",
    "    otherwise only item[:,:,0] is saved.\n",
    "    \"\"\"\n",
    "    for index, prediction in enumerate(npyfile):\n",
    "        if flag_multi_class:\n",
    "            picture = labelVisualize(num_class,COLOR_DICT,prediction)\n",
    "        else:\n",
    "            picture = prediction[:,:,0]\n",
    "        target = os.path.join(save_path,\"%d_predict.jpg\"%index)\n",
    "        io.imsave(target,picture)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stitch_image (images, slices):\n",
    "    \"\"\"Assemble a stack of equally sized tiles into one mosaic.\n",
    "\n",
    "    input:\n",
    "    images - array indexed as images[k, :, :]; tile k sits in grid row\n",
    "             k // n and grid column k % n, with n = int(sqrt(slices))\n",
    "    slices - total number of tiles in the mosaic\n",
    "\n",
    "    Returns the mosaic as a single array.\n",
    "    Raises ValueError if slices < 1 or if images holds fewer than n * n tiles.\n",
    "    \"\"\"\n",
    "    block_row = int(sqrt(slices))\n",
    "    if block_row < 1:\n",
    "        raise ValueError('slices must be at least 1, got %r' % (slices,))\n",
    "    if len(images) < block_row * block_row:\n",
    "        raise ValueError('need %d tiles, got %d' % (block_row * block_row, len(images)))\n",
    "\n",
    "    # One nested list per grid row; a single np.block call joins the inner lists\n",
    "    # side by side and stacks the rows, instead of re-copying a growing mosaic.\n",
    "    grid = [\n",
    "        [images[row*block_row+col,:,:] for col in range(block_row)]\n",
    "        for row in range(block_row)\n",
    "    ]\n",
    "    return np.block(grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 1: Crop each 1024 x 1024 image into 1024 pieces of 32 x 32 pixels and save them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder with the full-size source images\n",
    "Path_full_size = r'C:\\Users\\yizhe\\Desktop\\Velux\\data\\Change_plane\\Full_size'\n",
    "# Folder that receives the cropped tiles\n",
    "Path_crop = r'C:\\Users\\yizhe\\Desktop\\Velux\\python\\NN\\Data\\Predict_stitch\\Crop_1024\\image'\n",
    "# Folder for per-tile predictions (only referenced by the commented-out saveResult call)\n",
    "Predict_crop = r'C:\\Users\\yizhe\\Desktop\\Velux\\python\\NN\\Data\\Predict_stitch\\Crop_1024\\Predict'\n",
    "# Tiles per image: a 32 x 32 grid\n",
    "N_pieces = 32 * 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read_crop_write only writes tile files and returns None, so its result is not kept.\n",
    "read_crop_write(Path_full_size, Path_crop, N_pieces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 2: read the selected samples to prepare for prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only one image is processed in this step, so start and stop are the same index.\n",
    "N_image_start = 2\n",
    "N_image_stop = N_image_start\n",
    "\n",
    "# Tiles per image, taken from the cropping step.\n",
    "# NOTE(review): the original remark here was 'for some reason it is not 1024,\n",
    "# it only has 1 - 31' - presumably about the per-axis tile indices; confirm.\n",
    "N_slices = N_pieces\n",
    "\n",
    "testGene = testGenerator(\n",
    "    Path_crop, N_slices, N_image_start, N_image_stop\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 3: predict the image patch and Step 4: stitch the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): besides unet, this wildcard import presumably also supplies the\n",
    "# os / io / trans names used elsewhere in the notebook - confirm before narrowing it.\n",
    "from model import *\n",
    "\n",
    "# Build the network and restore the trained weights\n",
    "model = unet()\n",
    "model.load_weights(\"unet_membrane.hdf5\")\n",
    "\n",
    "# Run the generator for N_pieces steps, then drop the length-1 axes\n",
    "results = model.predict_generator(testGene, N_pieces, verbose = 1)\n",
    "images = np.squeeze(results)\n",
    "\n",
    "# Stitch the tile predictions together and resize the mosaic to 1024 x 1024\n",
    "whole_img = trans.resize(stitch_image (images, N_pieces),[1024, 1024])\n",
    "\n",
    "path = r'C:\\Users\\yizhe\\Desktop\\Velux\\python\\NN\\Data\\Predict_stitch\\Crop_1024\\Stitch'\n",
    "filename = path + '\\\\' + str(N_image_start)+'.jpg'\n",
    "\n",
    "# NOTE(review): values are multiplied by 100 before saving - confirm the intended intensity scale.\n",
    "cv2.imwrite(filename,whole_img*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a loop to predict all the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same predict -> stitch -> save sequence as the single-image cell, for image indices 3 to 51.\n",
    "for N_image_start in range(3, 52):\n",
    "    N_image_stop = N_image_start\n",
    "\n",
    "    # A generator is exhausted after one pass, so build a new one per image\n",
    "    testGene = testGenerator(Path_crop,N_slices, N_image_start, N_image_stop)\n",
    "    results = model.predict_generator(testGene, N_pieces, verbose = 1)\n",
    "    images = np.squeeze(results)\n",
    "\n",
    "    whole_img = trans.resize(stitch_image (images, N_pieces),[1024, 1024])\n",
    "\n",
    "    filename = path + '\\\\' + str(N_image_start)+'.jpg'\n",
    "    cv2.imwrite(filename,whole_img*100)\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
